import torch
import os
import tqdm
import copy
import time

import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
import math
from munch import Munch

# some_file.py
import sys
# insert at 1, 0 is the script path (or '' in REPL)

from mli.models.layers import LIBatchNorm2d, LILinear, LIConv2d
from sacred import Experiment
from mli.data import load_data, corrupt_dataset_labels
from mli.optim import get_optimizer
from mli.models import get_activation_function, get_loss_fn, interpolate_state
from mli.models import FCNet, warm_bn
from mli.metrics.gauss_len import compute_avg_gauss_length, compute_avg_gauss_length_bn
from mli.sacred import SlurmFileStorageObserver
from mli.metrics import param_dist

from mli_eval.model.interp import interp_networks
from mli_eval.model.loss import EvalClassifierLoss

# Sacred experiment setup. The observer is given RUN_DIR/EXPERIMENT_NAME as its
# base directory; main() and train_network() read and write files (checkpoints,
# metrics.json) in the per-run directory RUN_DIR/EXPERIMENT_NAME/get_run_dir().
EXPERIMENT_NAME = "mli_fcnet"
RUN_DIR = "runs"
ex = Experiment(EXPERIMENT_NAME)
ex.observers.append(SlurmFileStorageObserver(os.path.join(RUN_DIR, EXPERIMENT_NAME)))


@ex.config
def get_config():
    # Sacred config scope: each local assigned below becomes a config entry.
    # The names are read by the @ex.capture functions and (via cfg) by
    # main()/train_network(), so they must not be renamed.

    # Data Config
    dset_name = "fmnist"
    # Accepted by load_data_captured() but not forwarded to load_data().
    generate_data = False
    datasize = 60000
    # Fraction of training labels corrupted at epoch `corrupt_on_epoch`.
    random_label_proportion = 0.0
    batchsize = 512
    # Epoch at which labels are corrupted; -1 never matches an epoch index,
    # so no corruption happens by default.
    corrupt_on_epoch = -1
    num_classes = 10
    # Passed to load_data(); when False, "_balance" is appended to the metrics file name.
    imbalance = True
    # Class shift used to build `permutation`; encoded in the metrics file name
    # only when num_classes < 10.
    shift = 0
    # When shift != 0: reorders the 10 dataset class ids so that classes
    # num_classes..num_classes+shift-1 come first, then 0..num_classes-1, then
    # the rest (the literal 10 assumes a 10-class dataset). How load_data()
    # applies this ordering is not visible here.
    permutation = None if shift==0 else list(range(num_classes, num_classes+shift)) + list(range(num_classes)) + list(range(num_classes+shift, 10))

    # Model Config
    #hsizes = [100, 10]
    # FCNet gets depth-1 hidden layers of `width` units plus a num_classes output layer.
    depth = 10
    width = 1024
    # Must be "identity" or "relu" unless init_type == "default" (see get_model).
    act_fn = "relu"
    # "ce" (classification) or "recon" (the input is used as its own target).
    loss_fn = "ce"
    # None, "l1" or "outR" (see compute_loss).
    regularization = None
    # Multiplier for the "outR" penalty.
    regscale = 1
    # Target parameter norm for the "outR" penalty.
    outR_reg = 1
    use_batchnorm = False
    # Passed to FCNet; "no_bias" also disables per-class bias logging.
    bias = "all_bias"

    
    # Initialization
    # main() multiplies every LILinear weight by init_scale ** (1 / depth).
    init_scale = 1
    # Std multiplier for the LILinear bias re-init in main() (only when init_type == 'kaiming').
    bias_scale = 0
    # 'default' is passed to FCNet unchanged; any other value is replaced in
    # get_model() by 'xavier' (identity) or 'kaiming' (relu).
    init_type = 'kaiming'
    # Mirror the two halves of the last hidden layer at init (see main()).
    sym_init = False

    # Optimizer Config
    epochs = 10
    optim_name = "sgd"
    lr = 0.01
    # Forwarded to get_optimizer(); presumably a momentum coefficient — confirm in mli.optim.
    beta = 0.9

    # Misc
    # Forwarded to interp_networks(); presumably the number of interpolation points.
    alpha_steps = 50
    cuda = True
    # Training stops early if, after epoch `min_loss_epoch_check`, the epoch's
    # mean loss is still above `min_loss_threshold`.
    min_loss_threshold = 1
    min_loss_epoch_check = 10000
    # Log the (normalized and raw) parameter distance from init every epoch.
    log_wdist = True
    

    # Experiment Config
    # Not read anywhere in this file.
    tag = "fcnet"
    #seed = 0
    # When save_state is set, snapshot the model every `save_freq` epochs.
    save_freq = 1000
    # Save init/final/periodic checkpoints and a resumable checkpoint.pt.
    save_state = False
    # Compute the average Gauss length of the interpolation path.
    eval_gl = False
    # Forwarded to interp_networks() and encoded in the metrics file name.
    special_bias_interp = False
    # Encoded in the metrics file name as "_v{version}".
    version = 1

@ex.capture
def get_run_id(_run):
    """Return the id of the current sacred run (``_run`` is injected by sacred)."""
    run_id = _run._id
    return run_id

def get_data(d, n, model, margin):
    """Sample ``n`` Gaussian inputs that ``model`` labels with a confident margin.

    Inputs are drawn one at a time from a standard normal in ``d`` dimensions
    and kept only when the gap between the model's two largest logits is at
    least ``margin``. The label is the index of the largest logit.

    Args:
        d: input dimensionality.
        n: number of samples to return; ``n <= 0`` yields empty tensors.
        model: callable mapping a (1, d) tensor to (1, C) logits with C >= 2.
        margin: minimum required gap between the top-1 and top-2 logits.

    Returns:
        (inputs, labels): a (n, d) float tensor and a (n,) long tensor.

    Note:
        Sampling repeats until ``n`` inputs are accepted, so an unattainable
        ``margin`` never terminates.
    """
    inputs = []
    labels = []
    remaining = n
    # Labelling needs no gradients; avoid building an autograd graph per sample.
    with torch.no_grad():
        # `> 0` (not truthiness) so a negative n cannot loop forever.
        while remaining > 0:
            x = torch.randn(1, d)
            top_two = torch.topk(model(x), 2)
            if top_two.values[0][0] - top_two.values[0][1] < margin:
                continue
            inputs.append(x)
            labels.append(top_two.indices[0][0].item())
            remaining -= 1
    if not inputs:
        # torch.cat([]) raises, so build the empty result explicitly.
        return torch.empty(0, d), torch.empty(0, dtype=torch.long)
    return torch.cat(inputs, dim=0), torch.tensor(labels, dtype=torch.long)

@ex.capture
def load_data_captured(dset_name, batchsize, datasize, num_classes, generate_data, imbalance, permutation, train=True):
    """Call ``load_data`` with dataset settings injected from the sacred config.

    ``generate_data`` is filled in by sacred but is not forwarded to ``load_data``.
    """
    loader_options = dict(train=train, imbalance=imbalance, permutation=permutation)
    return load_data(dset_name, batchsize, datasize, num_classes, **loader_options)

@ex.capture
def get_optimizer_captured(optim_name, lr, beta):
    """Return ``get_optimizer(optim_name, lr, beta)`` with values from the sacred config.

    main() calls the returned object with the model parameters (or the model),
    so it is presumably an optimizer factory.
    """
    optimizer_factory = get_optimizer(optim_name, lr, beta)
    return optimizer_factory


@ex.capture
def get_model(depth, width, act_fn, use_batchnorm, num_classes, bias, init_type):
    """Build the FCNet described by the sacred config.

    The network takes 784-dimensional inputs and has ``depth - 1`` hidden
    layers of ``width`` units followed by a ``num_classes`` output layer.
    Unless ``init_type == 'default'``, the configured init type is replaced by
    one chosen from the activation: 'xavier' for identity, 'kaiming' for relu.

    Raises:
        Exception: if ``init_type != 'default'`` and ``act_fn`` is not
            'identity' or 'relu'.
    """
    if init_type != 'default':
        init_for_activation = {'identity': 'xavier', 'relu': 'kaiming'}
        if act_fn not in init_for_activation:
            raise Exception("Invalid activation given")
        init_type = init_for_activation[act_fn]
    layer_sizes = [width] * (depth - 1) + [num_classes]
    return FCNet(
        784,
        layer_sizes,
        act_fn=get_activation_function(act_fn),
        init_type=init_type,
        batch_norm=use_batchnorm,
        bias=bias,
    )


@ex.capture
def compute_loss(model, out, targets, regularization, loss_fn, regscale, outR_reg):
    """Return ``loss_fn(out, targets)`` plus an optional parameter penalty.

    ``regularization``, ``loss_fn``, ``regscale`` and ``outR_reg`` are injected
    from the sacred config.

    Regularization options:
        None: no penalty.
        "l1": sum over parameter tensors of the mean absolute entry value.
            NOTE(review): ``regscale`` is not applied to this penalty — confirm
            that is intended.
        "outR": sum over parameter tensors of
            ``regscale * (||p||_2 - outR_reg) ** 2``, where the norm is taken
            over the flattened tensor.

    Raises:
        ValueError: for any other ``regularization`` value. Silently returning
            None would only fail later, at ``loss.backward()``.
    """
    loss = get_loss_fn(loss_fn)(out, targets)
    if regularization is None:
        return loss
    if regularization == "l1":
        penalty = sum(F.l1_loss(param, torch.zeros_like(param)) for param in model.parameters())
        return loss + penalty
    if regularization == "outR":
        penalty = sum(
            regscale * (torch.linalg.norm(param) - outR_reg) ** 2
            for param in model.parameters()
        )
        return loss + penalty
    raise ValueError(f"Unknown regularization: {regularization!r}")


def eval_loss(model, loader, cuda, num_classes):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Every batch is flattened to ``(-1, 784)``, whether or not it is moved to
    the GPU. The model's train/eval mode is restored afterwards.

    Args:
        model: classifier mapping (batch, 784) inputs to logits with at least
            ``num_classes`` columns (assumed).
        loader: iterable of ``(x, y)`` batches exposing ``loader.dataset``;
            ``len(loader.dataset)`` is the normalizer.
        cuda: move each batch to the GPU when True.
        num_classes: number of leading softmax columns accumulated in ``u_array``.

    Returns:
        ``(loss, acc, u_array)``: mean cross-entropy and accuracy over the
        dataset, and ``u_array[i]``, the sum over all examples of the
        predicted probability of class ``i``.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    correct = 0.0
    u_array = [0.0] * num_classes
    try:
        with torch.no_grad():
            for x, y in loader:
                # Flatten on every device; previously only the CUDA path reshaped.
                x = x.reshape(-1, 784)
                if cuda:
                    x, y = x.cuda(), y.cuda()
                logits = model(x)
                # One host transfer per batch instead of one .item() per class.
                prob_sums = F.softmax(logits, dim=1).sum(dim=0).tolist()
                for i in range(num_classes):
                    u_array[i] += prob_sums[i]
                correct += (logits.argmax(1) == y).float().sum().item()
                # cross_entropy averages over the batch; re-weight by batch size.
                total_loss += F.cross_entropy(logits, y).item() * x.shape[0]
    finally:
        model.train(was_training)
    n_examples = len(loader.dataset)
    return total_loss / n_examples, correct / n_examples, u_array


def train_step(model, optimizer, x, y, compute_acc, num_classes):
    """Run one optimizer step on a batch and return scalar metrics for it.

    Args:
        model: network being trained.
        optimizer: optimizer over ``model``'s parameters.
        x: input batch.
        y: targets (class labels for "ce", the inputs themselves for "recon").
        compute_acc: when True, also compute accuracy and per-class metrics,
            treating ``y`` as class labels.
        num_classes: number of classes for the per-class metrics.

    Returns:
        A dict with "loss" (the value returned by compute_loss) and "norm"
        (mean L2 norm of the output rows). When ``compute_acc`` is set it also
        contains "acc" (batch accuracy) and, for each class ``k``:
          * "u_{i}_{k}": sum over examples labelled ``k`` of the predicted
            probability of class ``i``;
          * "loss{k}": summed cross-entropy of the examples labelled ``k``;
          * "acc{k}": number of correctly classified examples labelled ``k``.
    """
    optimizer.zero_grad()
    logits = model(x)
    loss = compute_loss(model, logits, y)
    loss.backward()
    optimizer.step()

    # Everything below is bookkeeping; detach so no autograd graph is built.
    logits = logits.detach()
    ret = {
        "loss": loss.item(),
        "norm": torch.norm(logits, dim=1).mean().item(),
    }
    if not compute_acc:
        return ret

    preds = logits.argmax(1)
    ret["acc"] = (preds == y).float().mean().item()

    # softmax/log_softmax are row-wise, so computing them once and selecting
    # rows per class matches per-class calls. log_softmax avoids the inf that
    # log(softmax(...)) produces when a probability underflows to 0.
    probs = F.softmax(logits, dim=1)
    neg_log_probs = -F.log_softmax(logits, dim=1)
    for k in range(num_classes):
        in_class = y == k
        # One host transfer per class instead of one .item() per entry.
        prob_sums = probs[in_class].sum(dim=0).tolist()
        for i in range(num_classes):
            ret[f"u_{i}_{k}"] = prob_sums[i]
        ret[f"loss{k}"] = neg_log_probs[in_class, k].sum().item()
        ret[f"acc{k}"] = (preds[in_class] == k).sum().item()
    return ret


def _train_epoch(model, optimizer, loader, epoch, cfg):
    """Run one training pass over ``loader`` and return train_step's per-batch metric dicts.

    Raises:
        Exception: if ``cfg.loss_fn`` is neither "ce" nor "recon".
    """
    if cfg.loss_fn not in ("ce", "recon"):
        raise Exception("Invalid loss function given")
    is_ce = cfg.loss_fn == "ce"

    model.train()
    batch_metrics = []
    # Running totals keep the progress-bar update O(1) per batch instead of
    # re-averaging the whole epoch history at every step.
    loss_total = 0.0
    acc_total = 0.0
    pbar = tqdm.tqdm(loader)
    for x, y in pbar:
        # Flatten on every device; previously only the CUDA path reshaped.
        x = x.reshape(-1, 784)
        if cfg.cuda:
            x, y = x.cuda(), y.cuda()
        if is_ce:
            metrics = train_step(model, optimizer, x, y, True, cfg.num_classes)
        else:
            # "recon": the flattened input is its own target.
            metrics = train_step(model, optimizer, x, x, False, cfg.num_classes)
        batch_metrics.append(metrics)
        seen = len(batch_metrics)
        loss_total += metrics["loss"]
        if is_ce:
            acc_total += metrics["acc"]
            pbar.set_description(
                "Epoch {:d} train | loss = {:0.6f}, acc = {:0.4f}".format(
                    epoch, loss_total / seen, acc_total / seen
                )
            )
        else:
            pbar.set_description(
                "Epoch {:d} train | loss = {:0.6f}".format(epoch, loss_total / seen)
            )
    return batch_metrics


def _log_epoch_metrics(model, train_metrics, cfg, _run):
    """Log one epoch's aggregated training metrics to the sacred run.

    loss/acc/norm are averaged over batches. The per-class entries (loss{k},
    acc{k}, u_{i}_{k}) are summed over batches.
    """
    def mean_of(key):
        return np.mean([m[key] for m in train_metrics])

    def sum_of(key):
        return np.sum([m[key] for m in train_metrics])

    if cfg.loss_fn == "recon":
        _run.log_scalar("train.loss", mean_of("loss"))
        return
    if cfg.loss_fn != "ce":
        raise Exception("Invalid loss function given")

    _run.log_scalar("train.loss", mean_of("loss"))
    _run.log_scalar("train.acc", mean_of("acc"))
    _run.log_scalar("train.norm", mean_of("norm"))

    # NOTE(review): assumes the last state_dict entry is the output layer's
    # bias (indexed per class) — confirm against FCNet. Fetched once, not per class.
    output_bias = list(model.state_dict().values())[-1] if cfg.bias != "no_bias" else None
    for k in range(cfg.num_classes):
        _run.log_scalar(f"train.loss{k}", sum_of(f"loss{k}"))
        _run.log_scalar(f"train.acc{k}", sum_of(f"acc{k}"))
        for i in range(cfg.num_classes):
            _run.log_scalar(f"train.u_{i}_{k}", sum_of(f"u_{i}_{k}"))
        if output_bias is not None:
            _run.log_scalar("train.bias" + str(k), output_bias[k].item())


def train_network(model, loader, class_loaders, optimizer, cfg, _run):
    """Train ``model`` for ``cfg.epochs`` epochs and log metrics to the sacred run.

    Training resumes from ``checkpoint.pt`` in the run directory
    (RUN_DIR/EXPERIMENT_NAME/get_run_dir()) when it exists. Otherwise it starts
    from scratch and, if ``cfg.save_state`` is set, saves ``init.pt``. With
    ``cfg.save_state``, it also writes ``checkpoint.pt`` every epoch, periodic
    ``model_{epoch}.pt`` snapshots, and ``final.pt``.

    Args:
        model: network to train (updated in place).
        loader: training DataLoader. Its labels are corrupted in place at epoch
            ``cfg.corrupt_on_epoch``.
        class_loaders: per-class loaders; currently unused.
        optimizer: optimizer over ``model``'s parameters.
        cfg: Munch view of the sacred config.
        _run: sacred Run used for ``log_scalar``.

    Returns:
        ``(init_state, final_state, losses)``:
          * init_state: a deep copy of the state dict taken before any
            checkpoint is restored;
          * final_state: a deep copy of the state dict after training;
          * losses: the mean training loss of each epoch run in this call.
    """
    init_state = copy.deepcopy(model.state_dict())
    checkpoint_dir = os.path.join(RUN_DIR, EXPERIMENT_NAME, get_run_dir())
    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pt")

    start_epoch = 0
    if os.path.isfile(checkpoint_path):
        print("Found an existing checkpoint. Loading state...")
        checkpoint = torch.load(checkpoint_path)
        start_epoch = checkpoint["epoch"]
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        print("Training from epoch {}".format(start_epoch))
    else:
        # First time on this run: optionally save the initial state.
        print("No checkpoint found. Training from scratch.")
        if cfg.save_state:
            # torch.save does not create missing parent directories.
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save({"model_state": init_state}, os.path.join(checkpoint_dir, "init.pt"))

    losses = []
    for epoch in range(start_epoch, cfg.epochs):
        if cfg.corrupt_on_epoch == epoch:
            corrupt_dataset_labels(loader, cfg.random_label_proportion)
            np.save(os.path.join(checkpoint_dir, "targets"), loader.dataset.targets.numpy())

        train_metrics = _train_epoch(model, optimizer, loader, epoch, cfg)
        mean_loss = np.mean([m["loss"] for m in train_metrics]) if train_metrics else 0
        losses.append(float(mean_loss))

        _log_epoch_metrics(model, train_metrics, cfg, _run)
        if cfg.log_wdist:
            _run.log_scalar("train.norm_wdist", param_dist(model.state_dict(), init_state, True))
            _run.log_scalar("train.wdist", param_dist(model.state_dict(), init_state, False))

        if epoch > cfg.min_loss_epoch_check and mean_loss > cfg.min_loss_threshold:
            print("Loss threshold not reached by epoch %s" % cfg.min_loss_epoch_check)
            print("Breaking out of training early...")
            break

        if cfg.save_state:
            # Resumable checkpoint: "epoch" is the next epoch to run.
            torch.save({
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, checkpoint_path)
            if cfg.save_freq > 0 and epoch > 0 and epoch % cfg.save_freq == 0:
                torch.save({
                    "epoch": epoch,
                    "model_state": model.state_dict(),
                }, os.path.join(checkpoint_dir, "model_{}.pt".format(epoch)))

    final_state = copy.deepcopy(model.state_dict())
    if cfg.save_state:
        torch.save({"model_state": final_state}, os.path.join(checkpoint_dir, "final.pt"))
    return init_state, final_state, losses


def get_run_dir():
    """Return the name of this run's directory under RUN_DIR/EXPERIMENT_NAME.

    The first of $SLURM_JOB_ID and $SLURM_ARRAY_JOB_ID that is set (even to an
    empty string) wins. Otherwise the sacred run id is used.
    """
    for env_var in ("SLURM_JOB_ID", "SLURM_ARRAY_JOB_ID"):
        value = os.getenv(env_var)
        if value is not None:
            return value
    return get_run_id()


@ex.automain
def main(seed, _run):
    """Sacred entry point: train an FCNet, evaluate its interpolation path, rename the metrics file.

    Steps, in order:
      1. Log the class permutation (if any) and build the train, per-class
         and eval loaders.
      2. Build the model and optimizer, rescale every LILinear weight by
         ``init_scale ** (1 / depth)`` and optionally re-draw LILinear biases.
      3. Optionally apply the symmetric init (``sym_init``).
      4. Train, then log loss/acc/norm at each alpha returned by
         ``interp_networks`` between the initial and final states, on the
         train and eval sets.
      5. Optionally log the average Gauss length of that path.
      6. Rename ``metrics.json`` in the run directory to a file name encoding
         the main hyperparameters.

    NOTE(review): keep the statement order. Data loading and model
    construction presumably consume the RNG seeded by sacred, so reordering
    may change results.

    ``seed`` is supplied by sacred and is not used directly in this body.
    """
    cfg = Munch.fromDict(_run.config)
    if cfg.permutation:
        for i in cfg.permutation:
            _run.log_scalar("permutation", i)

    # Element 0 is the training loader; the remaining elements are passed to
    # train_network() as class_loaders.
    loaders = load_data_captured()
    train_loader = loaders[0]
    class_loaders = loaders[1:]
    eval_loader = load_data_captured(train=False)
    model = get_model()
    if cfg.cuda:
        model = model.cuda()
    # NOTE(review): the optimizer factory is first tried with the parameter
    # iterator and, on AttributeError, with the model itself — presumably some
    # optimizers in mli.optim need the module; confirm.
    try:
        optimizer = get_optimizer_captured()(model.parameters())
    except AttributeError:
        optimizer = get_optimizer_captured()(model)

    for m in model.modules():
        if isinstance(m, LILinear):
            # Spread the overall init_scale evenly across the depth layers.
            m.weight.data = (cfg.init_scale**(1/cfg.depth))*m.weight.data
            # NOTE(review): this checks the raw config init_type, not the
            # activation-based value get_model() actually used — confirm intended.
            if m.bias is not None and cfg.init_type=='kaiming':
                m.bias.data.normal_(0, cfg.bias_scale*math.sqrt(1. / m.in_features))

    if cfg.sym_init:
        # Mirror the two halves of the last hidden layer: the output weights of
        # the first half are the negated second half, and the incoming weights
        # of the first half copy the second half. When the corresponding biases
        # also match, the two halves cancel in the output.
        # Assumes layers[-1] is the output layer and layers[-3] the preceding
        # linear layer — TODO confirm against FCNet.
        model.layers[-1].weight.data[:, :cfg.width//2] = -1*model.layers[-1].weight.data[:, cfg.width//2:]
        model.layers[-3].weight.data[:cfg.width//2, :] = model.layers[-3].weight.data[cfg.width//2:, :]

    # Train network
    init_state, final_state, _ = train_network(
        model, train_loader, class_loaders, optimizer, cfg, _run
    )
    # Evaluate interpolation path of networks
    # metrics[0] holds results on train_loader and metrics[1] on eval_loader
    # (the order of the list passed below).
    alphas, metrics = interp_networks(
        model, init_state, final_state, 
        train_loader, [train_loader, eval_loader],
        cfg.alpha_steps, EvalClassifierLoss(), cfg.cuda, cfg.special_bias_interp
    )
    for i in range(len(metrics[0]['loss'])):
        _run.log_scalar("train.interpolation.loss",  metrics[0]['loss'][i])
        _run.log_scalar("train.interpolation.acc", metrics[0]['acc'][i])
        _run.log_scalar("train.interpolation.norm", metrics[0]['norm'][i])
        _run.log_scalar("train.interpolation.alpha", alphas[i])

        _run.log_scalar("eval.interpolation.loss", metrics[1]['loss'][i])
        _run.log_scalar("eval.interpolation.acc", metrics[1]['acc'][i])
        _run.log_scalar("eval.interpolation.norm", metrics[1]['norm'][i])
        _run.log_scalar("eval.interpolation.alpha", alphas[i])
    
    # Evaluate the gauss length of the interpolation path
    if cfg.eval_gl:
        if not model.use_batchnorm:
            # This version is quicker
            avg_gl = compute_avg_gauss_length(model, init_state, final_state, alphas, eval_loader)
        else:
            # Slower but handles batch norm correctly
            avg_gl = compute_avg_gauss_length_bn(model, init_state, final_state, alphas, train_loader, eval_loader,
                                                bn_warm_steps=1)
        _run.log_scalar("gauss_len", avg_gl)

    # NOTE(review): presumably gives the sacred observer time to flush the
    # logged scalars to metrics.json before the file is renamed — confirm.
    time.sleep(60)
    # Build a metrics file name from the dataset, bias mode, depth,
    # batch-norm flag, init scale, bias interpolation mode and version.
    bias_interp = "special" if cfg.special_bias_interp else "normal"
    if cfg.use_batchnorm:
        bn='bn'
    else:
        bn=''
    dset_name = cfg.dset_name 

    if not cfg.imbalance:
        dset_name = dset_name + "_balance"
    if cfg.num_classes<10:
        dset_name = dset_name + str(cfg.num_classes) 
        if cfg.shift != 0:
            dset_name = dset_name + 's' + str(cfg.shift)

    metrics_name = f"{dset_name}_{cfg.bias}_{cfg.depth}{bn}_{cfg.init_scale}_{bias_interp}_v{cfg.version}.json"
   

    checkpoint_dir = os.path.join(RUN_DIR, EXPERIMENT_NAME, get_run_dir())
    source_name = os.path.join(checkpoint_dir, "metrics.json")
    dest_name = os.path.join(checkpoint_dir, metrics_name)
    # Raises if metrics.json does not exist in the run directory.
    os.rename(source_name, dest_name )